{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Задание 1.1 - Метод К-ближайших соседей (K-neariest neighbor classifier)\n",
    "\n",
    "В первом задании вы реализуете один из простейших алгоритмов машинного обучения - классификатор на основе метода K-ближайших соседей.\n",
    "Мы применим его к задачам\n",
    "- бинарной классификации (то есть, только двум классам)\n",
    "- многоклассовой классификации (то есть, нескольким классам)\n",
    "\n",
    "Так как методу необходим гиперпараметр (hyperparameter) - количество соседей, мы выберем его на основе кросс-валидации (cross-validation).\n",
    "\n",
    "Наша основная задача - научиться пользоваться numpy и представлять вычисления в векторном виде, а также ознакомиться с основными метриками, важными для задачи классификации.\n",
    "\n",
    "Перед выполнением задания:\n",
    "- запустите файл `download_data.sh`, чтобы скачать данные, которые мы будем использовать для тренировки\n",
    "- установите все необходимые библиотеки, запустив `pip install -r requirements.txt` (если раньше не работали с `pip`, вам сюда - https://pip.pypa.io/en/stable/quickstart/)\n",
    "\n",
    "Если вы раньше не работали с numpy, вам может помочь tutorial. Например этот:  \n",
    "http://cs231n.github.io/python-numpy-tutorial/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import load_svhn\n",
    "from knn import KNN\n",
    "from metrics import binary_classification_metrics, multiclass_accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Загрузим и визуализируем данные\n",
    "\n",
    "В задании уже дана функция `load_svhn`, загружающая данные с диска. Она возвращает данные для тренировки и для тестирования как numpy arrays.\n",
    "\n",
    "Мы будем использовать цифры из датасета Street View House Numbers (SVHN, http://ufldl.stanford.edu/housenumbers/), чтобы решать задачу хоть сколько-нибудь сложнее MNIST."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X, train_y, test_X, test_y = load_svhn(\"data\", max_train=1000, max_test=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples_per_class = 5  # Number of samples per class to visualize\n",
    "plot_index = 1\n",
    "for example_index in range(samples_per_class):\n",
    "    for class_index in range(10):\n",
    "        plt.subplot(5, 10, plot_index)\n",
    "        image = train_X[train_y == class_index][example_index]\n",
    "        plt.imshow(image.astype(np.uint8))\n",
    "        plt.axis('off')\n",
    "        plot_index += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Сначала реализуем KNN для бинарной классификации\n",
    "\n",
    "В качестве задачи бинарной классификации мы натренируем модель, которая будет отличать цифру 0 от цифры 9."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, let's prepare the labels and the source data\n",
    "\n",
    "# Only select 0s and 9s\n",
    "binary_train_mask = (train_y == 0) | (train_y == 9)\n",
    "binary_train_X = train_X[binary_train_mask]\n",
    "binary_train_y = train_y[binary_train_mask] == 0\n",
    "\n",
    "binary_test_mask = (test_y == 0) | (test_y == 9)\n",
    "binary_test_X = test_X[binary_test_mask]\n",
    "binary_test_y = test_y[binary_test_mask] == 0\n",
    "\n",
    "# Reshape to 1-dimensional array [num_samples, 32*32*3]\n",
    "binary_train_X = binary_train_X.reshape(binary_train_X.shape[0], -1)\n",
    "binary_test_X = binary_test_X.reshape(binary_test_X.shape[0], -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the classifier and call fit to train the model\n",
    "# KNN just remembers all the data\n",
    "knn_classifier = KNN(k=1)\n",
    "knn_classifier.fit(binary_train_X, binary_train_y)"
   ]
  },
  {
   "attachments": {
    "image.png": {
     "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR0AAAA3CAYAAAAxKp6fAAARVElEQVR4Ae3dB7B9P1EH8C9iQVFUVAQdxYJgL2CvgGUULGDH3sWOYm+joDK2sYwdG8Io9sJYsPfekWLFig0L2Lvz+U32T37nf85997xb3n3vZWfu5JycZLPZJJvNZpObDBgcGBwYHBgcGBwYHBgcGBwYHBgcGBwYHNgDB267BxwDxeDAVeLAbZK8UpJ/SfI/SV4xyfMl+cerVMmLrMsQOhfJ/VH2KXLg3ZO8UZIvS/IiSYyRL0jyzCRPPEWCLxtNz3bZCB70Dg4ckAPPmeQlkvxskjsk+eIk35rkD5K8aBJa0IDBgcGBwYG9coBm8/Akj2pCxsT8F0nuOYTOfvg8NJ398HFguToc+L8kb5zkp5J4ft0mbP44yX2SjDGzY1sPdXFHBo7sV44Dz5Xkn5O8TJKnJfmYJC/dDMnfnuQ3rlyNj1yhYUg+MsNHcSfPgRdI8vdJHp/kf5P8UZKXT/IzSX7u5KkfBA4ODA5cSg5MJ+Pp+6Ws1CB6cGBwYHDgWnJg2HSuZbNf20rr734MxBViRv9czJmLq29zIZwDtuAAxg4YHLgOHHj2JB+Q5F0PUFkC552SPP0AuK8cyiF0rlyTjgpt4MCrJPnVJJwAgZ2pL2nPfVBaDlsOw/ILN+/kF0py5yR3SXKPJPdt3+T9zPYbGk/PyfE8ODA4kI9qyyvC4T+SvOo5eUIwPXeS90ny1CR/leR258Q1su2RA1TaU3CmMmOh5aIBDaVd4ssp0LQvnvR1g/NU64b/P9AJnt9N8jw7MsHRCUcl3m0DHvy4Drtg+HvutjcoMOm8A5Ya+kNJ7rihIY71iYr8uHZ6+FhlTsvBR+d4quM5vfy53fs0/WV6Vyd1uVsj2vt3nXDd9Mm/7ATPN3aTwXn5TmP6iYU6G4iWX2++h3LOS9+x8t2r1XVR2Vj6IJ5hjGPUfyf56pUaC69Og/yzmqPVsSq8VI5rCazfvyfJnZYSHTieo9nzt46uKOo928JVsAHM1c1S41TrxvmPQRnd4L2TOF2+C/x2kkcuaE3FB+XV8y5lnXJefZyQXaznktDBnG/rGoI3ZjXQWRWG0wD/5SQ/f1biI36n/tI0Pn+lAN0XiRqi53dpPIuNs6+Cj4BH3fx6tVr9TrluzlaZFAu+IsnL1ss5w29O8k/nzHttsvWDYFppHYa1H/xwC7cJWPofkuQzVgiqbfDumkZ9vjzJOyR58V2RnSO/8nt+10AVXnZQN/UQFlT96v0UQ6fJTajARV3OVj1Hex/BgTjQD4JpEb69SZI/a5b56fel9/dP8ltJ/m4pwWRdu4mGDShu9Wk6eOc6/V8n+ekkaJymvxXCA0T0da3y+4FaRda3pfeKP1bY063MOd5OaZZnrm7HonmbcpgO3qXrq6+R5PO2yTjSnJ8DvToMi/eXS8Im45Y0QodxbNvOo6PxXfjFSR6zyFs39dVM8qPtJra3a1dC/n47YPeMlVXR0Z0AflAr7w+T/EiSf+i2Mn+yw6kev5Tk3l3csR7RqvwanFOe3jXJmyZ5yaYRfW27KvP129WZbAYM89N8h6SfAd6WsHX6n7S+ILxf0wwsV6tOrvasJWPV9ZC07Qs3gzJ7zg+2tqGlm5i+e18FbMDzVklcmWEHDR/tgjFy/06zpW7IepRPxrOVi00hNLnYTJ/QB/bSD3V4lxW5LY0moBEg/rAV1dPpDPj3mOSxVnYbG3x+NI4nNccs6q2432sCb5J18VXH/uzWOPKjl3pM8Hx0w/kpM7kZEP+tcxCbSXIjCsPX/pZwiUcvAQ4nsNvBMa2E0NsnecKERwSmNHhTdSTAjwHv1YSdcjnUfV3b8fmIdvUDgVO0C7+mncZGmzr+WPf9vPTCq09t2w6Vdm15ymHrU1c/fdhNgfsEu1fGWMH92wYN26cJ5hFt3LGHmvAvwgRQtAld7cG1gN3rg9pkzhdJv9gElI6HbdP2r9A60hc1bNWJNIBLqnvQsDQhaabgTll5FNyDtHYzEO37o5tWJR6+723xv7mFMCi879jywEdFhgc+s0ftEswJHduW8ty9EM2EOvkXJjHot/297wJPCj3aeqHzahOho0w2NLT5fXAnoMwwtcX7lV184d536CoHSw906EA0YPSLd9eM+KnQsXPjO1CXH1/gh36grbaBF2zCa9s2kI7mex6wk0gAFP9t+avHvqAXOrR9huzbJ/mWJDR8NxMCvGGacC9zlS/clmcNzU4BWv61M7Rre0ZyvHmxhlncdKXk05lCRyYVMktRjz+9I1VH/5skT25x1Kw3aLenaaA5Dah8cnTMHhD7n80DVDw7kfLECx+b5G2TGIiWEzrPJkBzSVxlfUfDIw+jt5nilRcQlK3peRe+iya0Pq7RN02G2eiewlJ8n06agv5ZnDJ7vv1pZ4gXT93/kCQfmOQju2+FT8hN/9UX6OvTeXbROA1mCnhLU61O/k2dqm8ZwO1gbnt5Wp/+3TO6tC2NzmRgCX4W0DhqkpimhXPaDuK23WWd4tM/uYmY+CwhHpjkoU0Dmqbd9R1vLfN5RNN+3N1j+QzUyaTOjwtIS/ATwNrdZDAH6v6GWxrCtbGVx7/PIIKHPPBvGJ/TfadhP6Wz775eku9s5oD/6tKd+UjoqBxibZErCCj4LSb2HHG+C5c8OH0DBMkasNwqYFM6S+hI+1Itw59POp9Ot6l8EhxwYd8ES5132tELx1J8fccbOItHnjW+98pbYeWpUNrikU7ICa8mg0ojFO+qzW3AEvPXFwZp8RYet+f1MMdbdRCvPqDqVGGLvrFTZAZf4m2l60M8mePLXFyf7zzPbBUEjaUkMKvTSlYNqpZ3LsAPYLB/UmsvqwMCqISJM100ID5y6ohXtAwKQKVpaG4KKAJvdlPM5hf4/aZgsubg9w2dUNKu7KCP6dqCeeQtZ3gzbfMp/hvqkRkIGOjVGSy3SFbSsDrS3zZV1/JlCUqLWBJKS/l6raNwLKUVrzEMQsuRtTYOsxjotYoWdUugzu/Z1rW3RJ7xgH+WT0uA5vpVmjMbqCWUrueRtpgDywO/bQDOau8+PRotgwuUW4K64qahPPCVQKp3IRC65lOaNaAfffzKfP46Zok/25TtQnYaBS3kE2YG1TY4ltIUP3zHKysH0PcbmzeAj5v0JmFnu85qA/RawvVlNFS3CpbaXsI6i9bLg9ds/a/ujSYUCcvSzvoClD/XzuJu0EbToVaCfuak0gJC58ObtZrdpYdbkHSR1GGIqfnbAjwuvwYk+a+1ZwP/nRsNtJke5NEoJLsTv36M4IDNoJZ5LeqmgDAFGnIJ1IEqOS13KT16tkl7g+lLSCaNVcJecvips4BRmUftHJyFv8+zlFZ8v/R5rSTf3zKio9b0Pa56XsJZ3yuEZxuAzxWhc8JxLj+eVX+e+75N3Ke25dVr7wHXpvLQ6v+1eMvbQADiLG2NgTI5EEx2CyttSzobbMunTe1Ufkq9PCgNys4wgePYCI0bjTZvNuFDqDFnp5rGdqN9GFQ9MMaq9Os06z2VUgGQCoHOwsC6dFbFdzPaJ7b0fQB3bbWxJxAotA4u6LQOhH9sl6EMxbbT5Z2Cda+y5LNzZUmgLsUE8VNDMvrsvmDoHM5pGft8V7a1e5VrRuGsKL5AHdDtR+ixgZhRrK3Fmc2Osd2PpuKjrVLasO1cu4VF39SQzMBtVgbqaKbs6ybeu2WdAb0tTHFsyrcm7RwerhcmzunmyVzatXG9IVlePDKBGHt2dw1idlKToX5suYTfhLzxwW3hGGD1QAvj4oJGQk/5llOUFKf0yQiG9gfPEDQ1JKtX9SUOw7epijOaueDIoHi/ViBDprUmY2PZeqqMTY1LAFgTbkrD3+Rt2slcW/QG2AOaQKsydHbpCLy55ZoBqPNqKPYZxxysOdVh0xINE8sLtco6VojfxReDdxMw2H5a8+OwM0YFV1+D+dCANpMCw6pB8Atty9QA8TwH8uhgoOrYXncKzuJTj3xN2j6fZzZMk4DwGP/kSRsnpK0gfqVpDYSNpUwJIwZkyoA0xzpeQevitU+40HAJH+3Jd4kwcg5TnyBcaC5ngTaxuWNpRqDd1DcMCJ1GCDyTbD3IQNOxo3FT5i4Re5Ab1GgiPcBbmo6Zu8oSX2X26T2T9nam5oROpZUXLWilGnpXvspONR3CyfZkraULxzFCdPWajgZw3qeve6/psJ35hk/q16c7Br3KUK4fGvBWaCscb3tNB21f1bkheO/dA4peuNZqOpX3kKFdU33GhsqhYKrp2AHER1qkvlvjoS8fH02+ljeWV9IcC5StPMtrdNICC+zs8i3Ct6k/01TTkUe709i5B9wW4gLrQZKs1oWee2u5jISAPAQKRombAq2FY1hPpHTTtCqirPpN8XhnJFbe3NZepZcfLrT69eX05XrW0LZ9e5tF4Tl0WPXt6Svald3TWu/yaIfKe2gap/iVWzTgreceqi5T+vr4Sq8dCX2hSeSYA6homAstG2nU3BHYMHeBaRsu4cIDOz+WcjQLvO3HXuXjpW6wslfaWi++1vdDhvommggRQNMBxj0bkxUNDdwkchboH1wl2ITg3Bo4fvEV4SHLiZAGsrQ215AMgIy6gHejZQHhgQAaD2eyTUw0s9q6rEo3VBsDa3HlsEcpx5bgl7ZyypXfjHERoK48PHU4wFXBTkvxwBLX0gXdfs6vccmv7y3bhQUmHLytzQJbuNRv9PlZmvQ2nd4j2Xe7UJxCed6aJbXtRddNn2Dfs4zdlRbt6twW/7Y5KE3HdwdN2TWNA0uoJXuNyd1yRnqD/ZhgDKOZLce49Ww7H594KNPSa+e7p2tO03G0h5Z/w01lTUWsM631DAiwqZE4/9kFcLETvweGKGs6v4J+W7bi+lA51DHq/LagIUnkvhy2KR3i65vT00XZc/DLjNYD4SheXfnDmAn6E/3sVKcCZr5qw6K56uO9NLKil82t0qkf/5wC7/Wt4o4dEqLf1zYjuPqjaRcwaZTH9hIedabl4WV5/yt3aTscDz+0IVulISwRsCK+dqAZgQuML/RSGIypOZrUcQqEp13wbbSiad7V71Qq6uF5Ya4C58HlHAtD+UWCujhvU5qO2e2itK5D8MGuWu1yqiMVfF/tt2960cUmxe2C8NkFTNzOKhpQJuUl4BZSmmCf5lR51NO45tnJBUvHq1avNTw46bRXqWEuU11o4FwxDBB014+w9Kv3CiuuvgsZ/W1UOCVu9uc7Nd086TsfXNcFNtZ1zfLqujDsmPXcVaU/Jq1nlXVZ6sK3hH3Jdi+frSkYMEt14eRmd9Zvegqc5nTD8W2KsL0v4VxIfqmjr1NdL3VDDeIPywHChM9J2Z4MjH3+rtJS+aAtsVENOmjJA/ngwPE5YBfVjyG312j656JqLq6+zYX8ksYMP8eZETc4MDgwODA4MDgwOHBxHNjFrrnJcHxxNTrxklnhBwwOXFcOcOdwxmitd7T0nFcdbB4mipW9ZwidlQwbya8UB3jUu950jS2GkOHBbLeKf86AwYHBgcGBrThAeFhardVyIJfH9bq83YemsxW7n5VoaDrP4sV4uj4cICg+uZ19Ky9hY0H80q++45Jt97kjANeHgzvUdBcj2g7FjqyDAxfKAUc0HFj1rwf3affXuC3zrCtPnJdy4LUHQmrN8qzPO54HBwYHrgEHaCwukHMEwgHGOhtYmkxpOkvvxSLn6MbyqrixIhyazgpmjaRXggMcAx+XxJ8uOvDpjmKChkF5k1GZEKprU64EIy6qEhg5YHDgunHAZOvfPpy/Yttx4v+s0+YEU90PJX+lZ1Qe42hFDxqazgpmjaRXhgNsMPq+i6poOq6wFfb/gDCtbNluhASVi/Ut0dwb5S7xY//P/JS+8T44MDhw4hy4U/tnjbVaCqFDYMnnR9M5z7b7ibNnkDc4MDgwODA4MDgwODA4MDgwODA4MDgwOHDiHPh/2s5bvlBRs4UAAAAASUVORK5CYII="
    }
   },
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Пришло время написать код! \n",
    "\n",
    "Последовательно реализуйте функции `compute_distances_two_loops`, `compute_distances_one_loop` и `compute_distances_no_loops`\n",
    "в файле `knn.py`.\n",
    "\n",
    "Эти функции строят массив расстояний между всеми векторами в тестовом наборе и в тренировочном наборе.  \n",
    "В результате они должны построить массив размера `(num_test, num_train)`, где координата `[i][j]` соотвествует расстоянию между i-м вектором в test (`test[i]`) и j-м вектором в train (`train[j]`).\n",
    "\n",
    "**Обратите внимание** Для простоты реализации мы будем использовать в качестве расстояния меру L1 (ее еще называют [Manhattan distance](https://ru.wikipedia.org/wiki/%D0%A0%D0%B0%D1%81%D1%81%D1%82%D0%BE%D1%8F%D0%BD%D0%B8%D0%B5_%D0%B3%D0%BE%D1%80%D0%BE%D0%B4%D1%81%D0%BA%D0%B8%D1%85_%D0%BA%D0%B2%D0%B0%D1%80%D1%82%D0%B0%D0%BB%D0%BE%D0%B2)).\n",
    "\n",
    "![image.png](attachment:image.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: implement compute_distances_two_loops in knn.py\n",
    "dists = knn_classifier.compute_distances_two_loops(binary_test_X)\n",
    "assert np.isclose(dists[0, 10], np.sum(np.abs(binary_test_X[0] - binary_train_X[10])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: implement compute_distances_one_loop in knn.py\n",
    "dists = knn_classifier.compute_distances_one_loop(binary_test_X)\n",
    "assert np.isclose(dists[0, 10], np.sum(np.abs(binary_test_X[0] - binary_train_X[10])))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: implement compute_distances_no_loops in knn.py\n",
    "dists = knn_classifier.compute_distances_no_loops(binary_test_X)\n",
    "assert np.isclose(dists[0, 10], np.sum(np.abs(binary_test_X[0] - binary_train_X[10])))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lets look at the performance difference\n",
    "%timeit knn_classifier.compute_distances_two_loops(binary_test_X)\n",
    "%timeit knn_classifier.compute_distances_one_loop(binary_test_X)\n",
    "%timeit knn_classifier.compute_distances_no_loops(binary_test_X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: implement predict_labels_binary in knn.py\n",
    "prediction = knn_classifier.predict(binary_test_X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: implement binary_classification_metrics in metrics.py\n",
    "precision, recall, f1, accuracy = binary_classification_metrics(prediction, binary_test_y)\n",
    "print(\"KNN with k = %s\" % knn_classifier.k)\n",
    "print(\"Accuracy: %4.2f, Precision: %4.2f, Recall: %4.2f, F1: %4.2f\" % (accuracy, precision, recall, f1)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's put everything together and run KNN with k=3 and see how we do\n",
    "knn_classifier_3 = KNN(k=3)\n",
    "knn_classifier_3.fit(binary_train_X, binary_train_y)\n",
    "prediction = knn_classifier_3.predict(binary_test_X)\n",
    "\n",
    "precision, recall, f1, accuracy = binary_classification_metrics(prediction, binary_test_y)\n",
    "print(\"KNN with k = %s\" % knn_classifier_3.k)\n",
    "print(\"Accuracy: %4.2f, Precision: %4.2f, Recall: %4.2f, F1: %4.2f\" % (accuracy, precision, recall, f1)) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Кросс-валидация (cross-validation)\n",
    "\n",
    "Попробуем найти лучшее значение параметра k для алгоритма KNN! \n",
    "\n",
    "Для этого мы воспользуемся k-fold cross-validation (https://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validation). Мы разделим тренировочные данные на 5 фолдов (folds), и по очереди будем использовать каждый из них в качестве проверочных данных (validation data), а остальные -- в качестве тренировочных (training data).\n",
    "\n",
    "В качестве финальной оценки эффективности k мы усредним значения F1 score на всех фолдах.\n",
    "После этого мы просто выберем значение k с лучшим значением метрики.\n",
    "\n",
    "*Бонус*: есть ли другие варианты агрегировать F1 score по всем фолдам? Напишите плюсы и минусы в клетке ниже."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the best k using cross-validation based on F1 score\n",
    "num_folds = 5\n",
    "train_folds_X = []\n",
    "train_folds_y = []\n",
    "\n",
    "# TODO: split the training data in 5 folds and store them in train_folds_X/train_folds_y\n",
    "\n",
    "k_choices = [1, 2, 3, 5, 8, 10, 15, 20, 25, 50]\n",
    "k_to_f1 = {}  # dict mapping k values to mean F1 scores (int -> float)\n",
    "\n",
    "for k in k_choices:\n",
    "    # TODO: perform cross-validation\n",
    "    # Go through every fold and use it for testing and all other folds for training\n",
    "    # Perform training and produce F1 score metric on the validation dataset\n",
    "    # Average F1 from all the folds and write it into k_to_f1\n",
    "\n",
    "    pass\n",
    "\n",
    "for k in sorted(k_to_f1):\n",
    "    print('k = %d, f1 = %f' % (k, k_to_f1[k]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Проверим, как хорошо работает лучшее значение k на тестовых данных (test data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO Set the best k to the best value found by cross-validation\n",
    "best_k = 1\n",
    "\n",
    "best_knn_classifier = KNN(k=best_k)\n",
    "best_knn_classifier.fit(binary_train_X, binary_train_y)\n",
    "prediction = best_knn_classifier.predict(binary_test_X)\n",
    "\n",
    "precision, recall, f1, accuracy = binary_classification_metrics(prediction, binary_test_y)\n",
    "print(\"Best KNN with k = %s\" % best_k)\n",
    "print(\"Accuracy: %4.2f, Precision: %4.2f, Recall: %4.2f, F1: %4.2f\" % (accuracy, precision, recall, f1)) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Многоклассовая классификация (multi-class classification)\n",
    "\n",
    "Переходим к следующему этапу - классификации на каждую цифру."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now let's use all 10 classes\n",
    "train_X = train_X.reshape(train_X.shape[0], -1)\n",
    "test_X = test_X.reshape(test_X.shape[0], -1)\n",
    "\n",
    "knn_classifier = KNN(k=1)\n",
    "knn_classifier.fit(train_X, train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: Implement predict_labels_multiclass\n",
    "predict = knn_classifier.predict(test_X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: Implement multiclass_accuracy\n",
    "accuracy = multiclass_accuracy(predict, test_y)\n",
    "print(\"Accuracy: %4.2f\" % accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Снова кросс-валидация. Теперь нашей основной метрикой стала точность (accuracy), и ее мы тоже будем усреднять по всем фолдам."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the best k using cross-validation based on accuracy\n",
    "num_folds = 5\n",
    "train_folds_X = []\n",
    "train_folds_y = []\n",
    "\n",
    "# TODO: split the training data in 5 folds and store them in train_folds_X/train_folds_y\n",
    "\n",
    "k_choices = [1, 2, 3, 5, 8, 10, 15, 20, 25, 50]\n",
    "k_to_accuracy = {}\n",
    "\n",
    "for k in k_choices:\n",
    "    # TODO: perform cross-validation\n",
    "    # Go through every fold and use it for testing and all other folds for validation\n",
    "    # Perform training and produce accuracy metric on the validation dataset\n",
    "    # Average accuracy from all the folds and write it into k_to_accuracy\n",
    "    pass\n",
    "\n",
    "for k in sorted(k_to_accuracy):\n",
    "    print('k = %d, accuracy = %f' % (k, k_to_accuracy[k]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Финальный тест - классификация на 10 классов на тестовой выборке (test data)\n",
    "\n",
    "Если все реализовано правильно, вы должны увидеть точность не менее **0.2**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO Set the best k as a best from computed\n",
    "best_k = 1\n",
    "\n",
    "best_knn_classifier = KNN(k=best_k)\n",
    "best_knn_classifier.fit(train_X, train_y)\n",
    "prediction = best_knn_classifier.predict(test_X)\n",
    "\n",
    "# Accuracy should be around 20%!\n",
    "accuracy = multiclass_accuracy(prediction, test_y)\n",
    "print(\"Accuracy: %4.2f\" % accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
